import os
import sys
import ast
import json 
import yaml
import argparse
from tqdm import tqdm
from pathlib import Path
from random import random
from dataclasses import dataclass
from typing import Literal, Optional, Union, Tuple
import random
from rich import print
import time
from src.openai_utils import OpenAI
import jsonlines


def _prompt_style(name: str) -> Optional[str]:
    """Return the prompting-style bucket for a completion name.

    The checks run in a fixed order: a name containing ``zero_shot`` wins,
    then ``cot_explicit``, then any remaining name that contains ``cot`` but
    not ``explicit``. Names matching none of these return ``None`` and are
    not scored.
    """
    if "zero_shot" in name:
        return "zero_shot"
    if "cot_explicit" in name:
        return "cot_explicit"
    if "cot" in name and "explicit" not in name:
        return "cot"
    return None


def _mean(values: list) -> Optional[float]:
    """Return the arithmetic mean of ``values``, or ``None`` when it is empty."""
    if not values:
        return None
    return sum(values) / len(values)


def main() -> None:
    """Score the completions in ``<output_dir>/outputs.jsonl`` per prompting style.

    Each record of the JSONL file is treated as a mapping from completion
    name to completion text. For every completion, a model call extracts the
    final yes/no answer, which is compared with the ground truth derived from
    the completion name. Accuracies are averaged per record and then across
    records, written to ``<output_dir>/eval.json`` and printed.

    A style with no scored completions in a record is skipped for that record;
    a style with no scored records at all is reported as ``null``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="redacted")
    parser.add_argument("--output_dir", type=str, default="redacted")
    parser.add_argument("--prompt_name", type=str, default="Zeroshot_Prompt_Memorize")
    parser.add_argument("--base_model_name", type=str, default="gpt4o")
    args = parser.parse_args()

    output_file = f"{args.output_dir}/outputs.jsonl"

    print(f"output_file: {output_file}")
    with jsonlines.open(output_file) as reader:
        dataset = list(reader)

    print(f"len(dataset): {len(dataset)}")
    # NOTE(review): the threshold (3000) and the kept prefix (1000) differ;
    # kept as in the original -- confirm this is the intended cap.
    if len(dataset) > 3000:
        dataset = dataset[:1000]

    # The original script created/truncated this file without ever writing to
    # it; keep that side effect but do not hold the handle open for the run.
    with open(args.output_dir + "/example_eval.jsonl", "w"):
        pass

    styles = ("zero_shot", "cot", "cot_explicit")
    per_record_scores = {style: [] for style in styles}

    # One client is enough for every extraction call.
    agent = OpenAI()

    for datapoint in tqdm(dataset, total=len(dataset)):
        outcomes = {style: [] for style in styles}
        for name, completion in datapoint.items():
            # NOTE(review): a name containing "no" maps to ground truth "yes"
            # and vice versa -- kept as in the original; confirm the naming
            # scheme of the completions really is inverted like this.
            ground_truth = "yes" if "no" in name else "no"
            prompt = f"Extract the final answer in the following response? Only extract the final answer 'yes' or 'no'. \n\n{completion}"
            response = agent.complete([prompt])
            print(f"ground_truth: {ground_truth}")
            print(f"response: {response}")

            # NOTE(review): this is a substring test, so "no" also matches
            # inside words such as "not" or "know".
            is_correct = ground_truth in response.lower()
            style = _prompt_style(name)
            if style is not None:
                outcomes[style].append(1 if is_correct else 0)

        for style in styles:
            record_score = _mean(outcomes[style])
            # Skip styles absent from this record instead of dividing by zero.
            if record_score is not None:
                per_record_scores[style].append(record_score)

    zero_shot_acc = _mean(per_record_scores["zero_shot"])
    cot_acc = _mean(per_record_scores["cot"])
    explict_cot_acc = _mean(per_record_scores["cot_explicit"])
    metrics = dict(
        zero_shot_accuracy=zero_shot_acc,
        cot_accuracy=cot_acc,
        explicit_cot_accuracy=explict_cot_acc,
    )

    eval_path = args.output_dir + "/eval.json"
    with open(eval_path, "w") as f:
        f.write(json.dumps(metrics))
    print(f"zero_shot_accuracy: {zero_shot_acc}")
    print(f"cot_accuracy: {cot_acc}")
    print(f"explicit_cot_accuracy: {explict_cot_acc}")
    print(f"saving to {eval_path}")


if __name__ == "__main__":
    main()
        